import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import torch


# Resolution (dots per inch) passed to every figure created by the plotting helpers in this module.
fig_dpi = 300


def plot_pca_variance_explained(pca_model_actor, pca_model_critic):
    """
    Draw per-component and cumulative explained-variance curves for two PCA fits.

    The figure is a 2x2 grid: the top row belongs to the actor model, the bottom
    row to the critic model. The left column plots ``explained_variance_ratio_``
    for each component; the right column plots its running sum, with the y-axis
    fixed to [0, 1].
    Args:
        pca_model_actor: fitted PCA-like object exposing ``explained_variance_ratio_`` (actor)
        pca_model_critic: fitted PCA-like object exposing ``explained_variance_ratio_`` (critic)
    returns:
        fig: matplotlib figure
    """
    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(6, 4), dpi=fig_dpi)

    panels = (('Actor', pca_model_actor), ('Critic', pca_model_critic))
    for row, (who, pca_model) in enumerate(panels):
        ratios = pca_model.explained_variance_ratio_
        per_pc_ax, cumulative_ax = axs[row]

        # Left column: variance explained by each individual component.
        per_pc_ax.plot(ratios)
        per_pc_ax.set_title(f'{who} var. explained per PC')
        per_pc_ax.set_xlabel('PC')
        per_pc_ax.set_ylabel('var. exp.')

        # Right column: cumulative fraction, bounded by 1 for a full decomposition.
        cumulative_ax.plot(np.cumsum(ratios))
        cumulative_ax.set_title(f'{who} cumulative var. explained')
        cumulative_ax.set_xlabel('PC')
        cumulative_ax.set_ylabel('cumulative var. exp.')
        cumulative_ax.set_ylim(0, 1)

    fig.tight_layout()
    return fig


def plot_traj_in_PC_space(
    transformed_RNN_hidden_states_background,
    background_coloring_array,
    background_colorbar_label,
    background_coloring_minmax,
    num_pcs2plot,
    transformed_RNN_hidden_states_example_run=None,
    transformed_fixed_points=None,
    fixed_points_coloring_array=None,
    fixed_points_condition=None
):
    """
    Plot hidden states projected onto every pair of the first ``num_pcs2plot`` PCs.

    One subplot is drawn per PC pair ``(pc_x, pc_y)`` with ``pc_x < pc_y``. The
    grid has ``num_pcs2plot`` columns and as many rows as needed. Grid cells left
    over after the last pair are hidden.

    Args:
        transformed_RNN_hidden_states_background: PC-projected hidden states drawn
            as the background cloud. It is indexed as ``[:, :, pc]``, so it must be
            3-D (presumably (runs, trials, PCs) — confirm against caller).
        background_coloring_array: values mapped through ``cm.coolwarm`` to color
            the background points.
        background_colorbar_label: label of the background colorbar.
        background_coloring_minmax: ``(vmin, vmax)`` for the background colormap.
        num_pcs2plot: number of leading PCs to pair up; must be >= 2.
        transformed_RNN_hidden_states_example_run: optional 2-D array indexed as
            ``[:, pc]`` (rows = time steps). It is drawn as a black line, with its
            points colored by time step.
        transformed_fixed_points: optional 2-D array indexed as ``[:, pc]``. These
            fixed points are drawn on top, colored on a log scale by
            ``fixed_points_coloring_array``.
        fixed_points_coloring_array: per-fixed-point values (labelled as kinetic
            energy / q value) used for the fixed-point colormap.
        fixed_points_condition: text inserted into the subplot titles when fixed
            points are drawn.

    Returns:
        fig: matplotlib figure

    Raises:
        ValueError: if ``num_pcs2plot < 2`` (there is no PC pair to plot).
    """
    if num_pcs2plot < 2:
        raise ValueError(
            f'num_pcs2plot must be at least 2 to form a PC pair, got {num_pcs2plot}'
        )

    # Every unordered pair of distinct PCs, in the row-major order the subplots
    # are filled.
    pc_pairs = [
        (pc_x, pc_y)
        for pc_x in range(num_pcs2plot)
        for pc_y in range(pc_x + 1, num_pcs2plot)
    ]
    n_axs = len(pc_pairs)
    n_cols = num_pcs2plot
    n_rows = (n_axs + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False keeps ``axs`` 2-D even when there is a single row
    # (num_pcs2plot of 2 or 3). With the default squeeze, a 1-D array would be
    # returned and 2-D indexing into it would fail.
    fig, axs = plt.subplots(
        nrows=n_rows, ncols=n_cols,
        figsize=(6*n_cols, 4*n_rows), dpi=fig_dpi,
        squeeze=False
    )
    flat_axs = axs.ravel()

    if transformed_fixed_points is not None:
        title = f'Fixed points under {fixed_points_condition}'
    else:
        title = 'Neural space trajectory'

    if transformed_RNN_hidden_states_example_run is not None:
        example_run_len = transformed_RNN_hidden_states_example_run.shape[0]

    for ax, (pc_x, pc_y) in zip(flat_axs, pc_pairs):
        ax.set_title(title, fontsize=18)

        # Plot the background
        scatter = ax.scatter(
            transformed_RNN_hidden_states_background[:, :, pc_x],
            transformed_RNN_hidden_states_background[:, :, pc_y],
            s=2.5,
            c=background_coloring_array, cmap=cm.coolwarm,
            vmin=background_coloring_minmax[0],
            vmax=background_coloring_minmax[1],
            alpha=0.9
        )
        ax.set_xlabel(f'PC {pc_x}', fontsize=18)
        ax.set_ylabel(f'PC {pc_y}', fontsize=18)
        ax.tick_params(axis='both', which='major', labelsize=12)
        fig.colorbar(
            scatter, ax=ax,
            label=background_colorbar_label
        )

        # Plot the trajectory of the example run
        if transformed_RNN_hidden_states_example_run is not None:
            ax.plot(
                transformed_RNN_hidden_states_example_run[:, pc_x],
                transformed_RNN_hidden_states_example_run[:, pc_y],
                color='k', lw=1, alpha=0.9
            )
            example_run = ax.scatter(
                transformed_RNN_hidden_states_example_run[:, pc_x],
                transformed_RNN_hidden_states_example_run[:, pc_y],
                s=4,
                c=np.arange(example_run_len), cmap=cm.copper,
                vmin=0, vmax=example_run_len
            )  # color coded by time step
            fig.colorbar(
                example_run, ax=ax,
                label='Time step'
            )

        # Plot the fixed points
        if transformed_fixed_points is not None:
            scatter_fp = ax.scatter(
                transformed_fixed_points[:, pc_x],
                transformed_fixed_points[:, pc_y],
                marker='o', s=5,
                c=fixed_points_coloring_array, cmap=cm.gray,
                vmin=1e-13, vmax=10, norm='log',
                alpha=0.6
            )  # color coded by q values
            fig.colorbar(
                scatter_fp, ax=ax,
                label='Kinetic Energy (q value)'
            )

    # Hide grid cells that received no PC pair.
    for ax in flat_axs[n_axs:]:
        ax.set_axis_off()

    fig.tight_layout()
    return fig


def plot_traj_in_3D_PC_space(
    transformed_RNN_hidden_states_background,
    background_coloring_array,
    background_colorbar_label,
    background_coloring_minmax,
    pc_x, pc_y, pc_z,
    transformed_RNN_hidden_states_example_run=None
):
    """
    Plot the transformed hidden states in a 3-D PCA space.

    Args:
        transformed_RNN_hidden_states_background: PC-projected hidden states drawn
            as the background cloud. It is indexed as ``[:, :, pc]``, so it must be
            3-D (presumably (runs, trials, PCs) — confirm against caller).
        background_coloring_array: values mapped through ``cm.coolwarm`` to color
            the background points.
        background_colorbar_label: label of the background colorbar.
        background_coloring_minmax: ``(vmin, vmax)`` for the background colormap.
        pc_x, pc_y, pc_z: indices of the principal components used for the x, y
            and z axes.
        transformed_RNN_hidden_states_example_run: optional 2-D array indexed as
            ``[:, pc]`` (rows = time steps). It is drawn as a black line, with its
            points colored by time step. Pass ``None`` (the default) to omit it.

    Returns:
        fig: matplotlib figure
    """
    fig = plt.figure(figsize=(15,6), dpi=fig_dpi)
    ax = fig.add_subplot(projection='3d')
    ax.set_title('Neural space trajectory', fontsize=18)

    # Plot the background
    scatter = ax.scatter(
        transformed_RNN_hidden_states_background[:, :, pc_x],
        transformed_RNN_hidden_states_background[:, :, pc_y],
        transformed_RNN_hidden_states_background[:, :, pc_z],
        s=3.0, alpha=0.8,
        c=background_coloring_array, cmap=cm.coolwarm,
        vmin=background_coloring_minmax[0],
        vmax=background_coloring_minmax[1],
        zorder=10
    )
    ax.set_xlabel(f'PC {pc_x}', fontsize=18)
    ax.set_ylabel(f'PC {pc_y}', fontsize=18)
    ax.set_zlabel(f'PC {pc_z}', fontsize=18)
    fig.colorbar(
        scatter, ax=ax,
        label=background_colorbar_label
    )

    # Plot the trajectory of the example run (higher zorder: drawn over the cloud)
    if transformed_RNN_hidden_states_example_run is not None:
        ax.plot(
            transformed_RNN_hidden_states_example_run[:, pc_x],
            transformed_RNN_hidden_states_example_run[:, pc_y],
            transformed_RNN_hidden_states_example_run[:, pc_z],
            color='k', lw=1, alpha=0.95,
            zorder=50
        )
        example_run_len = transformed_RNN_hidden_states_example_run.shape[0]
        example_run = ax.scatter(
            transformed_RNN_hidden_states_example_run[:, pc_x],
            transformed_RNN_hidden_states_example_run[:, pc_y],
            transformed_RNN_hidden_states_example_run[:, pc_z],
            s=10,
            c=np.arange(example_run_len), cmap=cm.copper,
            vmin=0, vmax=example_run_len,
            zorder=60
        )  # color coded by time step
        fig.colorbar(
            example_run, ax=ax,
            label='Time step'
        )

    fig.tight_layout()
    return fig


def _jacobians_at_points(
    model,
    x_batch,
    curr_states,
    prev_actions,
    prev_rewards,
    critic_prev_hidden_states,
):
    """
    Compute the Jacobian of the model's one-step hidden-state update at each point.

    Args:
        model: recurrent model called as
            ``model(curr_state, prev_action, prev_reward, hidden, critic_hidden)``;
            its third output is taken as the next hidden state.
        x_batch: hidden states; point ``i`` is ``x_batch[:, i, :]``.
        curr_states, prev_actions, prev_rewards, critic_prev_hidden_states:
            model inputs. Only index 0 along dim 1 is used, so the inputs are
            assumed identical for every point (TODO confirm against caller).

    Returns:
        list of tensors, one squeezed Jacobian per point
        (``J[a, b] = d next_state[a] / d state[b]``).
    """
    # The inputs do not depend on the point, so slice them once.
    curr_state_1 = curr_states[:, 0, :]
    prev_action_1 = prev_actions[:, 0, :]
    prev_reward_1 = prev_rewards[:, 0, :]
    critic_prev_hidden_state_1 = critic_prev_hidden_states[:, 0, :]

    def one_step(x):
        _, _, x_one_step, _ = model(
            curr_state_1, prev_action_1, prev_reward_1,
            x, critic_prev_hidden_state_1
        )
        return x_one_step.squeeze()

    all_J = []
    for i in range(x_batch.size(1)):
        J = torch.autograd.functional.jacobian(one_step, x_batch[:, i, :])
        all_J.append(J.squeeze())
    return all_J


def find_fixed_points_parallel(
    model,
    args,
    hidden_states_trajs,  # actor_hidden_states_test.shape (3, 400, 64)
    # inputs
    curr_states,
    prev_actions,
    prev_rewards,
    critic_prev_hidden_states,
    n_inits=800,
    q_threshold=0.001,
    noise_scale=0.1,
    learning_rate=0.001,
    max_iters=6000,
    save_interval=200,
    print_interval=1000,
    device='cpu',
    seed=None
):
    """
    Search for approximate fixed points of the RNN update, optimizing all starts jointly.

    ``n_inits`` hidden states are sampled uniformly from ``hidden_states_trajs``
    and perturbed with Gaussian noise of scale ``noise_scale``. Adam then
    minimizes the mean of ``q = 0.5 * ||F(x) - x||^2`` over all starts, where
    ``F`` is the model's one-step hidden-state update. At each save iteration,
    the states, q, |dq| and the per-point Jacobians are recorded.

    Side effect: every parameter of ``model`` gets ``requires_grad = False``.

    Args:
        model: recurrent model called as
            ``model(curr_states, prev_actions, prev_rewards, hidden, critic_hidden)``;
            its third output is the next hidden state.
        args: unused; kept for interface compatibility.
        hidden_states_trajs: 3-D tensor (runs, trials, hidden_dim) of recorded states.
        curr_states, prev_actions, prev_rewards, critic_prev_hidden_states:
            model inputs held fixed during the search.
        n_inits: number of starting points.
        q_threshold, print_interval: unused; kept for interface compatibility.
        noise_scale: std of the Gaussian perturbation added to the sampled starts.
        learning_rate: Adam learning rate.
        max_iters: number of optimization steps.
        save_interval: spacing of the snapshot iterations
            (``np.linspace(save_interval, max_iters, max_iters // save_interval)``).
        device: torch device for the computation.
        seed: if given, seeds both torch and numpy.

    Returns:
        pandas.DataFrame with one row per (snapshot, start) and columns
        ``x_init``, ``x_star``, ``q_star``, ``dq``, ``x_star_jac``
        (flattened, per-point transposed Jacobian) and ``n_iters``.
    """
    # output
    x_init_list, x_star_list = [], []
    q_star_list, dq_list = [], []
    x_star_jac_list = []
    n_iters_list = []
    # A set gives O(1) membership tests in the hot loop.
    save_iters = set(
        np.linspace(save_interval, max_iters, int(max_iters/save_interval), dtype=int).tolist()
    )

    # set the seed
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    model = model.to(device)
    hidden_states_trajs = hidden_states_trajs.to(device)
    curr_states = curr_states.to(device)
    prev_actions = prev_actions.to(device)
    prev_rewards = prev_rewards.to(device)
    critic_prev_hidden_states = critic_prev_hidden_states.to(device)

    # Freeze the network parameters; only the hidden states are optimized.
    for param in model.parameters():
        param.requires_grad = False

    # Sample starting points along the recorded trajectories.
    n_runs, n_trials, hidden_dims = hidden_states_trajs.shape
    hidden_state_pts = hidden_states_trajs.reshape(-1, hidden_dims)
    init_idx = torch.randint(n_runs*n_trials, size=(n_inits,), device=device)
    states = hidden_state_pts[init_idx, :].reshape(-1, n_inits, hidden_dims)

    # add Gaussian noise to the sampled points
    states = states + noise_scale * torch.randn_like(states, device=device)

    # The states themselves are the optimization variables.
    states = states.detach()
    initial_states = states.clone().cpu().numpy()
    states.requires_grad = True

    opt = torch.optim.Adam([states], lr=learning_rate)

    q_prev = torch.full((n_inits,), float("nan"), device=device)
    for iter_count in range(1, max_iters + 1):
        # Compute q for every start. Reshape explicitly (instead of squeeze) so
        # the per-start axis survives when n_inits == 1.
        _, _, new_states, _ = model(
            curr_states, prev_actions, prev_rewards,
            states, critic_prev_hidden_states
        )
        diff = new_states.reshape(n_inits, -1) - states.reshape(n_inits, -1)
        q = 0.5 * torch.sum(diff ** 2, dim=1)
        dq = torch.abs(q.detach() - q_prev)
        q_scalar = torch.mean(q)

        # Backpropagate gradients and optimize
        q_scalar.backward()
        opt.step()
        opt.zero_grad()

        if iter_count in save_iters:
            # Jacobians are evaluated at the post-step states.
            all_J = _jacobians_at_points(
                model, states.detach().clone(),
                curr_states, prev_actions, prev_rewards, critic_prev_hidden_states
            )
            dFdx = torch.stack(all_J).detach().cpu().numpy()

            q_np = q.detach().cpu().numpy()
            x_init_list.append(initial_states)
            x_star_list.append(states.detach().clone().cpu().numpy().squeeze())
            q_star_list.append(q_np)
            dq_list.append(dq.cpu().numpy())
            # Transpose each point's Jacobian on its own. A bare ``.T`` on the
            # (n_points, H, H) stack reverses ALL axes, which interleaves
            # different points once flattened to rows below.
            x_star_jac_list.append(np.swapaxes(dFdx, 1, 2))
            n_iters_list.append(np.full_like(q_np, iter_count))

        # Keep only the value; the previous graph is not needed for dq.
        q_prev = q.detach()

    print("Maximum iteration count reached. Terminating.")

    x_init_arr = np.array(x_init_list).reshape(-1, hidden_dims)
    x_star_arr = np.array(x_star_list).reshape(-1, hidden_dims)
    q_star_arr = np.array(q_star_list).reshape(-1)
    dq_arr = np.array(dq_list).reshape(-1)
    x_star_jac_arr = np.array(x_star_jac_list).reshape(-1, hidden_dims*hidden_dims)
    n_iters_arr = np.array(n_iters_list).reshape(-1)

    df_fixed_points = pd.DataFrame({
        'x_init': list(x_init_arr),
        'x_star': list(x_star_arr),
        'q_star': list(q_star_arr),
        'dq': list(dq_arr),
        'x_star_jac': list(x_star_jac_arr),
        'n_iters': list(n_iters_arr)
    })

    return df_fixed_points


def find_fixed_points_sequential(
    model,
    args,
    hidden_states_trajs,  # actor_hidden_states_test.shape (3, 400, 64)
    # inputs
    curr_states,
    prev_actions,
    prev_rewards,
    critic_prev_hidden_states,
    n_inits=800,
    q_threshold=0.001,
    noise_scale=0.1,
    learning_rate=0.001,
    max_iters=6000,
    save_interval=1000,
    print_interval=1000,
    device='cpu',
    seed=None
):
    """
    Search for approximate fixed points of the RNN update, one start at a time.

    ``n_inits`` hidden states are sampled uniformly from ``hidden_states_trajs``
    and perturbed with Gaussian noise of scale ``noise_scale``. Each start is
    then optimized separately with its own Adam optimizer, minimizing
    ``q = 0.5 * ||F(x) - x||^2``, where ``F`` is the model's one-step
    hidden-state update.

    Side effect: every parameter of ``model`` gets ``requires_grad = False``.

    Args:
        model: recurrent model called as
            ``model(curr_states, prev_actions, prev_rewards, hidden, critic_hidden)``;
            its third output is the next hidden state.
        args: unused; kept for interface compatibility.
        hidden_states_trajs: 3-D tensor (runs, trials, hidden_dim) of recorded states.
        curr_states, prev_actions, prev_rewards, critic_prev_hidden_states:
            model inputs, passed unchanged for every start.
        n_inits: number of starting points.
        q_threshold, print_interval: unused; kept for interface compatibility.
        noise_scale: std of the Gaussian perturbation added to the sampled starts.
        learning_rate: Adam learning rate.
        max_iters: number of optimization steps per start.
        save_interval: spacing of the snapshot iterations
            (``np.linspace(save_interval, max_iters, max_iters // save_interval)``).
        device: torch device for the computation.
        seed: if given, seeds both torch and numpy.

    Returns:
        pandas.DataFrame with one row per (start, snapshot) and columns
        ``x_init``, ``x_star``, ``q_star``, ``dq`` and ``n_iters``.
    """
    # output
    x_init_list, x_star_list = [], []
    q_star_list, dq_list = [], []
    n_iters_list = []
    # A set gives O(1) membership tests in the hot loop.
    save_iters = set(
        np.linspace(save_interval, max_iters, int(max_iters/save_interval), dtype=int).tolist()
    )

    # set the seed
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    model = model.to(device)
    hidden_states_trajs = hidden_states_trajs.to(device)
    curr_states = curr_states.to(device)
    prev_actions = prev_actions.to(device)
    prev_rewards = prev_rewards.to(device)
    critic_prev_hidden_states = critic_prev_hidden_states.to(device)

    # Freeze the network parameters; only the hidden states are optimized.
    for param in model.parameters():
        param.requires_grad = False

    # Sample starting points along the recorded trajectories.
    n_runs, n_trials, hidden_dims = hidden_states_trajs.shape
    hidden_state_pts = hidden_states_trajs.reshape(-1, hidden_dims)
    init_idx = torch.randint(n_runs*n_trials, size=(n_inits,), device=device)
    states = hidden_state_pts[init_idx].reshape(-1, n_inits, hidden_dims)

    # add Gaussian noise to the sampled points
    states = states + noise_scale * torch.randn_like(states, device=device)

    for idx in range(n_inits):
        # This start becomes the single optimization variable.
        state = states[:, idx, :].reshape(-1, 1, hidden_dims).detach()
        initial_state = state.detach().clone().cpu().numpy()
        state.requires_grad = True

        opt = torch.optim.Adam([state], lr=learning_rate)

        q_prev = torch.full((1,), float("nan"), device=device)
        for iter_count in range(1, max_iters + 1):
            _, _, new_state, _ = model(
                curr_states, prev_actions, prev_rewards,
                state, critic_prev_hidden_states
            )
            q = 0.5 * torch.sum((new_state.squeeze() - state.squeeze()) ** 2, dim=0)
            dq = torch.abs(q.detach() - q_prev)
            q_scalar = torch.mean(q)

            # Backpropagate gradients and optimize
            q_scalar.backward()
            opt.step()
            opt.zero_grad()

            # Copy to host only when saving; doing it every iteration forces a
            # device sync for values that were never used.
            if iter_count in save_iters:
                q_np = q.detach().cpu().numpy()
                x_init_list.append(initial_state)
                x_star_list.append(state.detach().cpu().numpy().squeeze())
                q_star_list.append(q_np)
                dq_list.append(dq.cpu().numpy())
                n_iters_list.append(np.full_like(q_np, iter_count))

            # Keep only the value; the previous graph is not needed for dq.
            q_prev = q.detach()

        print("Maximum iteration count reached. Terminating.")

    x_init_arr = np.array(x_init_list).reshape(-1, hidden_dims)
    x_star_arr = np.array(x_star_list).reshape(-1, hidden_dims)
    q_star_arr = np.array(q_star_list).reshape(-1)
    dq_arr = np.array(dq_list).reshape(-1)
    n_iters_arr = np.array(n_iters_list).reshape(-1)

    df_fixed_points = pd.DataFrame({
        'x_init': list(x_init_arr),
        'x_star': list(x_star_arr),
        'q_star': list(q_star_arr),
        'dq': list(dq_arr),
        'n_iters': list(n_iters_arr)
    })

    return df_fixed_points


def find_fixed_points_yang(
    info,
    model,
    args,
    hidden_states_trajs,  # actor_hidden_states_test.shape (3, 400, 64)
    # inputs
    curr_states,
    prev_actions,
    prev_rewards,
    critic_prev_hidden_states,
    batch_size=1,
    n_inits=800,
    q_threshold=0.001,
    noise_scale=0.1,
    learning_rate=0.001,
    max_iters=6000,
    save_interval=1000,
    print_interval=1000,
    device='cpu',
    seed=None
):
    """
    Search for fixed points starting near block transitions (Yang-style MSE loss).

    Starting trials are the 10 trials around each block boundary
    (``cumsum(info['block_lens'])`` excluding the last, from -5 to +4). Ids
    outside ``[0, n_trials)`` are skipped. For each starting trial, the hidden
    state of the LAST run is tiled ``batch_size`` times, uniform noise in
    [0, 1) is added (``np.random.rand``), and Adam minimizes
    ``MSE(F(x), x)``, where ``F`` is the model's one-step update. The recorded
    ``q`` is ``0.5 * mean((F(x) - x)^2)`` per point. It is a mean over hidden
    units, not a sum, so it is smaller by a factor of hidden_dim than the q of
    the other finders.

    NOTE(review): ``noise_scale`` is not used here; the added noise is always
    uniform in [0, 1). Confirm that this is intended.

    Side effect: every parameter of ``model`` gets ``requires_grad = False``.

    Args:
        info: dict with a ``'block_lens'`` sequence of block lengths (in trials).
        model: recurrent model called as
            ``model(curr_states, prev_actions, prev_rewards, hidden, critic_hidden)``;
            its third output is the next hidden state.
        args: object with ``rnn_hidden_dim``.
        hidden_states_trajs: 3-D tensor (runs, trials, hidden_dim) of recorded states.
        curr_states, prev_actions, prev_rewards, critic_prev_hidden_states:
            model inputs, passed unchanged for every start.
        batch_size: number of noisy copies optimized jointly per starting trial.
        n_inits, q_threshold, noise_scale, print_interval: unused; kept for
            interface compatibility.
        learning_rate: Adam learning rate.
        max_iters: number of optimization steps per starting trial.
        save_interval: spacing of the snapshot iterations
            (``np.linspace(save_interval, max_iters, max_iters // save_interval)``).
        device: torch device for the computation.
        seed: if given, seeds both torch and numpy.

    Returns:
        pandas.DataFrame with one row per (starting trial, snapshot, batch
        element) and columns ``x_init``, ``x_star``, ``q_star``, ``dq`` and
        ``n_iters``.
    """
    # output
    x_init_list, x_star_list = [], []
    q_star_list, dq_list = [], []
    n_iters_list = []
    # A set gives O(1) membership tests in the hot loop.
    save_iters = set(
        np.linspace(save_interval, max_iters, int(max_iters/save_interval), dtype=int).tolist()
    )

    # set the seed
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    model = model.to(device)
    hidden_states_trajs = hidden_states_trajs.to(device)
    curr_states = curr_states.to(device)
    prev_actions = prev_actions.to(device)
    prev_rewards = prev_rewards.to(device)
    critic_prev_hidden_states = critic_prev_hidden_states.to(device)

    # Freeze the network parameters; only the hidden states are optimized.
    for param in model.parameters():
        param.requires_grad = False

    # Starting trials: 10 trials around every block transition. Out-of-range ids
    # are skipped; negative ids would otherwise silently index from the end.
    n_trials = hidden_states_trajs.shape[1]
    transition_trial_ids = np.cumsum(info['block_lens'])
    init_points_trial_ids = [
        trial_id
        for transition in transition_trial_ids[:-1]
        for trial_id in range(transition - 5, transition + 5)
        if 0 <= trial_id < n_trials
    ]

    # Host copy of the last run, so np.tile works whatever the device is, and
    # even when the recorded states still require grad.
    last_run_states = hidden_states_trajs[-1].detach().cpu().numpy()

    for init_point_trial_id in init_points_trial_ids:
        trial_points = last_run_states[init_point_trial_id]
        traj_points_tiled = np.tile(trial_points, (batch_size, 1)).reshape(1, batch_size, args.rnn_hidden_dim)
        # Build the leaf tensor directly on ``device``. Calling ``.to(device)``
        # after ``requires_grad=True`` returns a non-leaf on non-CPU devices,
        # which Adam refuses to optimize.
        state = torch.tensor(
            traj_points_tiled + np.random.rand(1, batch_size, args.rnn_hidden_dim),
            dtype=torch.float32, device=device, requires_grad=True
        )
        initial_state = state.detach().clone().cpu().numpy()

        opt = torch.optim.Adam([state], lr=learning_rate)

        q_prev = np.nan
        for iter_count in range(1, max_iters + 1):
            opt.zero_grad()   # zero the gradient buffers

            # Take the one-step recurrent function from the trained network
            _, _, new_state, _ = model(
                curr_states, prev_actions, prev_rewards,
                state, critic_prev_hidden_states
            )
            # Per-point MSE. Its mean equals torch.nn.MSELoss()(new_state, state),
            # but keeping it per point makes q line up one-to-one with the
            # batch_size rows of x_init / x_star.
            per_point_mse = ((new_state - state) ** 2).reshape(batch_size, -1).mean(dim=1)
            loss = per_point_mse.mean()
            loss.backward()
            opt.step()    # Does the update

            q = 0.5 * per_point_mse.detach().cpu().numpy()
            dq = np.abs(q - q_prev)

            if iter_count in save_iters:
                x_init_list.append(initial_state)
                x_star_list.append(state.detach().cpu().clone().numpy().squeeze())
                q_star_list.append(q)
                dq_list.append(dq)
                n_iters_list.append(np.full_like(q, iter_count))

            q_prev = q

        print("Maximum iteration count reached. Terminating.")

    x_init_arr = np.array(x_init_list).reshape(-1, hidden_states_trajs.shape[-1])
    x_star_arr = np.array(x_star_list).reshape(-1, hidden_states_trajs.shape[-1])
    q_star_arr = np.array(q_star_list).reshape(-1)
    dq_arr = np.array(dq_list).reshape(-1)
    n_iters_arr = np.array(n_iters_list).reshape(-1)

    df_fixed_points = pd.DataFrame({
        'x_init': list(x_init_arr),
        'x_star': list(x_star_arr),
        'q_star': list(q_star_arr),
        'dq': list(dq_arr),
        'n_iters': list(n_iters_arr)
    })

    return df_fixed_points